mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 12:49:50 +08:00
clean up, reformatting will come later
@@ -30,8 +30,7 @@ class ModelArgs(BaseModelArgs):
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
pscan: bool = False # use parallel scan mode or sequential mode when training
use_mambapy: bool = False
pscan: bool = False
tie_word_embeddings: bool = True
@@ -102,13 +101,6 @@ def clamp(x, min=None, max=None):
def pscan_f(A, X):
# A : (B, D, L, N)
# X : (B, D, L, N)
# modifies X in place by doing a parallel scan.
# more formally, X will be populated by these values: H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
# which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)
Aa = A
Xa = X
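The comment block above defines the recurrence that pscan_f evaluates. As an illustration only (not part of this commit), the same values can be computed with a plain sequential loop over the time axis; this sketch assumes operands already in the (B, D, L, N) layout used inside pscan_f:

import mlx.core as mx

def scan_sequential(A, X):
    # Reference semantics: h[t] = A[t] * h[t-1] + X[t], starting from a zero state.
    # A, X : (B, D, L, N); returns an array of the same shape holding every h[t].
    B, D, L, N = X.shape
    h = mx.zeros((B, D, N))
    outputs = []
    for t in range(L):
        h = A[:, :, t] * h + X[:, :, t]
        outputs.append(h)
    return mx.stack(outputs, axis=2)  # (B, D, L, N)

The parallel version reaches the same result in roughly 2*log2(T) passes by repeatedly combining adjacent pairs, which is what the strided indexing in the following hunk is doing.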
@@ -157,10 +149,10 @@ def pscan_f(A, X):
A[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Aa[:, :, :, 0], mx.array([last_val_aa]).reshape(B, D, 1, -1)], axis=2)
X[:, :, 2**k-1::2**(k+1)] = mx.concatenate([Xa[:, :, :, 0], mx.array([last_val_xa]).reshape(B, D, 1, -1)], axis=2)
# main function, used in the Mamba model (mamba_mlx.py)
def pscan(A_in, X_in):
"""
Applies the parallel scan operation, as defined above. Returns a new tensor.
Applies the parallel scan operation, as defined above. Returns a new array.
Args:
A_in: mx.array -> Shape (B, L, ED, N)
@@ -169,12 +161,9 @@ def pscan(A_in, X_in):
Returns:
H: mx.array -> Shape (B, L, ED, N)
"""
A = A_in[:].transpose(0, 2, 1, 3)
X = X_in[:].transpose(0, 2, 1, 3)
pscan_f(A, X)
return X.transpose(0, 2, 1, 3)
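A usage sketch (example sizes chosen arbitrarily, not from the commit): the wrapper accepts both operands in the (B, L, ED, N) layout used elsewhere in the model, moves the scan axis next to the state dimensions for pscan_f, and returns the hidden states in the original layout.

import mlx.core as mx

B, L, ED, N = 2, 8, 16, 4                        # example sizes only
deltaA = mx.random.uniform(shape=(B, L, ED, N))
BX = mx.random.normal(shape=(B, L, ED, N))

hs = pscan(deltaA, BX)    # hs[:, t] = deltaA[:, t] * hs[:, t-1] + BX[:, t]
print(hs.shape)           # (2, 8, 16, 4)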
@@ -223,7 +212,6 @@ class MambaBlock(nn.Module):
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, ssm_state=None):
# Modify this method to work without state during training
A = -mx.exp(self.A_log) # (ED, N)
D = self.D # (ED,)
@@ -237,7 +225,6 @@ class MambaBlock(nn.Module):
BX = deltaB * mx.expand_dims(x, -1) # (B, ED, N)
if self.training:
# During training, we don't use or update the state
new_ssm_state = BX
else:
if ssm_state is None:
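For context, a sketch (not the author's code) of the recurrent update the inference path maintains once deltaA and BX have been formed for a single token; shapes follow the comments in this method:

import mlx.core as mx

def recurrent_step(deltaA, BX, C, D, x, ssm_state):
    # deltaA, BX, ssm_state : (B, ED, N)   C : (B, N)   D : (ED,)   x : (B, ED)
    new_state = deltaA * ssm_state + BX                   # h[t] = dA * h[t-1] + dB * x[t]
    y = (new_state * mx.expand_dims(C, 1)).sum(axis=-1)   # contract over N -> (B, ED)
    return y + D * x, new_state

In the training branch above the carried state is simply dropped and only BX for the current batch is kept; during generation the updated state is what gets carried into the next call.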
@@ -255,7 +242,6 @@ class MambaBlock(nn.Module):
def ssm(self, x):
# x : (B, L, ED)
# y : (B, L, ED)
A = -mx.exp(self.A_log) # (ED, N)
@@ -280,7 +266,6 @@ class MambaBlock(nn.Module):
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B, 2) # (B, L, ED, N)
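The two lines above discretize A and B across the whole sequence at once. A sketch of how these discretized operands are typically combined with the scan and the skip connection, using the pscan defined earlier in this file (illustrative only; B_mat stands for the input-dependent B to avoid clashing with the batch size):

import mlx.core as mx

def selective_scan(x, delta, A, B_mat, C, D):
    # x, delta : (B, L, ED)   A : (ED, N)   B_mat, C : (B, L, N)   D : (ED,)
    deltaA = mx.exp(mx.expand_dims(delta, -1) * A)                  # (B, L, ED, N)
    deltaB = mx.expand_dims(delta, -1) * mx.expand_dims(B_mat, 2)   # (B, L, ED, N)
    BX = deltaB * mx.expand_dims(x, -1)                             # (B, L, ED, N)
    hs = pscan(deltaA, BX)                                          # scan along L
    return (hs * mx.expand_dims(C, 2)).sum(axis=-1) + D * x         # (B, L, ED)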
@@ -300,7 +285,6 @@ class MambaBlock(nn.Module):
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
_, L, _ = x.shape
deltaA = mx.exp(mx.expand_dims(delta, -1) * A) # (B, L, ED, N)
@@ -366,7 +350,6 @@ class ResidualBlock(nn.Module):
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache: MambaCache, layer_idx: int):
# Ensure x is 3D before passing to mixer
if x.ndim == 2:
x = mx.expand_dims(x, 1) # Make it (B, 1, D)
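A minimal illustration (example sizes only) of the promotion this guard performs, so a single-token (B, D) input enters the mixer in the (B, L, D) layout:

import mlx.core as mx

x = mx.zeros((4, 768))        # (B, D): a single token per batch element
if x.ndim == 2:
    x = mx.expand_dims(x, 1)  # -> (B, 1, D)
print(x.shape)                # (4, 1, 768)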